{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:55:05.429178Z",
     "start_time": "2025-09-02T02:55:05.410293Z"
    }
   },
   "source": [
     "import gymnasium as gym\n",
     "\n",
     "\n",
     "# Define the environment: wrap Pendulum-v1 and expose the old-style gym API\n",
     "# (reset -> state, step -> 4-tuple) with a hard 200-step episode limit.\n",
     "class MyWrapper(gym.Wrapper):\n",
     "    def __init__(self):\n",
     "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
     "        super().__init__(env)\n",
     "        self.env = env\n",
     "        self.step_n = 0  # number of steps taken in the current episode\n",
     "\n",
     "    def reset(self):\n",
     "        # Drop the info dict returned by gymnasium; return only the state.\n",
     "        state, _ = self.env.reset()\n",
     "        self.step_n = 0\n",
     "        return state\n",
     "\n",
     "    def step(self, action):\n",
     "        # Collapse gymnasium's (terminated, truncated) pair into one done flag.\n",
     "        state, reward, terminated, truncated, info = self.env.step(action)\n",
     "        done = terminated or truncated\n",
     "        self.step_n += 1\n",
     "        # Force the episode to end after 200 steps.\n",
     "        if self.step_n >= 200:\n",
     "            done = True\n",
     "        return state, reward, done, info\n",
     "\n",
     "\n",
     "env = MyWrapper()\n",
     "\n",
     "env.reset()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.10863715,  0.9940815 , -0.425623  ], dtype=float32)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 25
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:55:05.510014Z",
     "start_time": "2025-09-02T02:55:05.460503Z"
    }
   },
   "source": [
     "from matplotlib import pyplot as plt\n",
     "\n",
     "%matplotlib inline\n",
     "\n",
     "\n",
     "# Render the current game frame inline with matplotlib.\n",
     "def show():\n",
     "    plt.imshow(env.render())\n",
     "    plt.show()\n",
     "\n",
     "\n",
     "show()"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ],
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjUsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvWftoOwAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHwlJREFUeJzt3X9wlPWdwPHPbn6RX5uQYBJzJEIPR+T4UQWU1JvxpqREy1hRboY6DqaUkZECA+JwI63G0etMGLyr1VZib5wKf1TopCNaKWgz4ZeekR9BLARM1aGSE5MImE2IZvPre/P9wq5ZjLqBkHx2837NPH1293mSbJ6Gfft8n+8mHmOMEQAAFPIO9xMAAODrECkAgFpECgCgFpECAKhFpAAAahEpAIBaRAoAoBaRAgCoRaQAAGoRKQCAWsMWqWeffVbGjRsno0aNkptvvln2798/XE8FAKDUsETqj3/8o6xevVoee+wxOXTokEybNk1KSkqkubl5OJ4OAEApz3D8gll75jRz5kz57W9/6+739vZKQUGBrFixQh5++OGhfjoAAKXih/oLdnZ2Sm1traxduzb0mNfrleLiYqmpqen3YwKBgFuCbNTOnj0r2dnZ4vF4huR5AwAGjz0/amtrk/z8fNcANZE6ffq09PT0SG5ubtjj9v57773X78eUl5fL448/PkTPEAAwVBoaGmTs2LF6InUp7FmXvYYV5Pf7pbCw0H1zPp9vWJ8bAGDgWltb3WWe9PT0b9xvyCM1ZswYiYuLk6amprDH7f28vLx+PyYpKcktF7OBIlIAEL2+7ZLNkM/uS0xMlOnTp0t1dXXYNSZ7v6ioaKifDgBAsWEZ7rNDd6WlpTJjxgy56aab5Ne//rW0t7fLokWLhuPpAACUGpZILViwQD799FMpKyuTxsZG+e53vyuvvfbaVyZTAABGtmF5n9RgXHDLyMhwEyi4JgUA0SfS13F+dx8AQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBAGInUnv37pU77rhD8vPzxePxyMsvvxy23RgjZWVlcvXVV0tycrIUFxfL+++/H7bP2bNn5d577xWfzyeZmZmyePFiOXfu3OV/NwCAkR2p9vZ2mTZtmjz77LP9bl+/fr0888wz8txzz8m+ffskNTVVSkpKpKOjI7SPDVRdXZ1UVVXJtm3bXPiWLFlyed8JACD2mMtgP3zr1q2h+729vSYvL888+eSTocdaWlpMUlKS2bx5s7t/7Ngx93EHDhwI7bNjxw7j8XjMxx9/HNHX9fv97nPYNQAg+kT6Oj6o16ROnDghjY2NbogvKCMjQ26++Wapqalx9+3aDvHNmDEjtI/d3+v1ujOv/gQCAWltbQ1bAACxb1AjZQNl5ebmhj1u7we32XVOTk7Y9vj4eMnKygrtc7Hy8nIXu+BSUFAwmE8bAKBUVMzuW7t2rfj9/tDS0NAw3E8JABBtkcrLy3PrpqamsMft/eA2u25ubg7b3t3d7Wb8Bfe5WFJSkpsJ2HcBAMS+QY3U+PHjXWiqq6tDj9nrR/ZaU1FRkbtv1y0tLVJbWxvaZ+fOndLb2+uuXQEAEBQvA2Tfz/TBBx+ETZY4fPiwu6ZUWFgoq1atkl/+8pdy
7bXXumg9+uij7j1V8+bNc/tff/31ctttt8n999/vpql3dXXJ8uXL5cc//rHbDwCAkIFOG9y1a5ebNnjxUlpaGpqG/uijj5rc3Fw39Xz27Nmmvr4+7HOcOXPG3HPPPSYtLc34fD6zaNEi09bWNuhTFwEAOkX6Ou6x/yNRxg4h2ll+dhIF16cAIPpE+joeFbP7AAAjE5ECAKhFpAAAahEpAIBaRAoAoBaRAgCoRaQAAGoRKQCAWkQKAKAWkQIAqEWkAABqESkAgFpECgCgFpECAKhFpAAAahEpAIBaRAoAoBaRAgCoRaQAAGoRKQCAWkQKAKAWkQIAqEWkAABqESkAgFpECgCgFpECAKhFpAAAahEpAIBaRAoAoBaRAgCoRaQAAGoRKQCAWkQKAKAWkQIAqEWkAABqESkAgFpECgCgFpECAKhFpAAAahEpAIBaRAoAoBaRAgCoRaQAAGoRKQCAWkQKAKAWkQIAqEWkAABqESkAgFpECgCgFpECAKhFpAAAahEpAIBaRAoAoBaRAgCoRaQAAGoRKQCAWkQKAKAWkQIAqEWkAABqESkAgFpECgCgFpECAKhFpAAAsRGp8vJymTlzpqSnp0tOTo7MmzdP6uvrw/bp6OiQZcuWSXZ2tqSlpcn8+fOlqakpbJ+TJ0/K3LlzJSUlxX2eNWvWSHd39+B8RwCAkRmpPXv2uAC9/fbbUlVVJV1dXTJnzhxpb28P7fPggw/Kq6++KpWVlW7/U6dOyd133x3a3tPT4wLV2dkpb731lmzatEk2btwoZWVlg/udAQCin7kMzc3Nxn6KPXv2uPstLS0mISHBVFZWhvY5fvy426empsbd3759u/F6vaaxsTG0T0VFhfH5fCYQCET0df1+v/ucdg0AiD6Rvo5f1jUpv9/v1llZWW5dW1vrzq6Ki4tD+0ycOFEKCwulpqbG3bfrKVOmSG5ubmifkpISaW1tlbq6un6/TiAQcNv7LgCA2HfJkert7ZVVq1bJLbfcIpMnT3aPNTY2SmJiomRmZobta4NktwX36Ruo4Pbgtq+7FpaRkRFaCgoKLvVpAwBGQqTstamjR4/Kli1b5Epbu3atO2sLLg0NDVf8awIAhl/8pXzQ8uXLZdu2bbJ3714ZO3Zs6PG8vDw3IaKlpSXsbMrO7rPbgvvs378/7PMFZ/8F97lYUlKSWwAAI8uAzqSMMS5QW7dulZ07d8r48ePDtk+fPl0SEhKkuro69Jidom6nnBcVFbn7dn3kyBFpbm4O7WNnCvp8Ppk0adLlf0cAgJF5JmWH+F588UV55ZVX3HulgteQ7HWi5ORkt168eLGsXr3aTaaw4VmxYoUL06xZs9y+dsq6jdHChQtl/fr17nM88sgj7nNztgQA6Mtjp/hJhDweT7+Pv/DCC/KTn/wk9Gbehx56SDZv3uxm5dmZexs2bAgbyvvoo49k6dKlsnv3bklNTZXS0lJZt26dxMdH1kw7u88G0V6fsiEEAESXSF/HBxQpLYgUAES3SF/H+d19AAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUCt+uJ8AMFIZYwa0v8fjuWLPBdCKSAFXOkTGiOnu/nLp6XHrno4O6Wppkc6mJuk8fdotHSdPStfZsyJer9sv3ueT1AkTZPS//quk/PM/S1xqKrHCiEKkgEGKUe/nn0v3559LT3u79Jw759bd9nZrq4tR12efuaXb75fOs2elx+//1s/b9emn8sWHH8qZXbsk48YbJWfePEm7/npChRGDSAGRDsPZEHV2ujOdrjNnXHA67frCbRenL76Q3i++kB4bq44Od9t0dV3+c+rslJa335aOTz6RwiVLJG3yZEKFEYFIQUb6MFxvV5cLSfB2byAg3a2t0vnpp9LZ3Hx+KM7ebmqS7rY2kd5eMRcW6bMeCh0ffSQn
/+d/pHDpUs6oMCIQKcT+MFwgID1tbW7ozcYnNBRnh+H8fjf8Zofj7Dq42GANl4/b2+Wds2elratLrho1SoquukpSExLCQvXJ5s3ynf/4D4lPTx+25wkMBSKF6ByGC96+sO7t7pZuO/x2+rQbjnNnP3YY7vTpL4fhAgHp7eg4PwzX0eGG0LR9fyfOnZPH3nlH/nHunHT09IgvIUEmjx4t/zVzpiR4v3zHSNu778rnH34o6dOmcTaFmOYxA50Hq0Bra6tkZGSI3+8Xn8833E8Hg8z+SNrhN3v9xy42JsHb9owoNPx2YUZcwA7DffbZ+eE7O+xm1/bH+sLtaPFhW5ss+d//FX8/17BuGjNG/vOGGyR71KjQY0n5+fIvFRVEClEp0tdxzqQwbGx07DUeNxTX1uaG3uxQnBtya211j7mhODtEd2G7CQQkVv26rq7fQFn7T5+WqlOn5Mff+U7oscGYkAFoR6QwqBMRgktoYkJPj4uNO/O5MPxmb9uzHxskNwRnz5LsOhA4f9YUwyECMDBECpfFBsUNuTU2ugi5a0J2sbc//dRdH7IRCoXLCsYMAL4FkcIlsdd+Pv/gAzn7xhty7uhRdxEfl2duQYEctIHvJ+Dj0tJkalbWsDwvYDgRKQyYnZ59etcuadyyxZ09DdV7hGJdSX6+W//y3Xels6dH7FGN83gkMzFR/nvmTLkmLS1s/9z584fpmQJDh0hhQOz07dOvv+7ep2N/qwIGj52lZ0M1NiVFtv3f/8mZjg53BrVg/HjJTkoK2zfp6qtldFERM/sQ84gUBjTEZ9+f0/jSSwTqCrHRse+LssvXiR89WvIXLnS/fBaIdfw9KUTM/h66k889d/49SRgWcWlpcvWCBZJ5003iiYsb7qcDXHGcSSFiZ3bvPv9nJDDkbJBGFRRI7l13Sda//RvDfBgxiBQidnb3bqaODzFvSookX3ONZMyY4ZbkwkIChRGFSAEKeOLj3dmSG8K7sLZLzty5MmbOHBcrbzz/XDHy8FMPDIW4OIlLSTm/JCe76ARv2yUhO1viMzMl4cISn5Hh1t7kZM6cMKIRKWAQz4bis7LOh8aubWhGj3ZLXHr6+SDZM6JRo86H6kKgPAkJhAj4GkQKEbP/dT+iht8uWuxwm/fCWU9idvb59Zgx59dZWW5KeGi4zut1S/C2eDyECLgERAoRG7tokfhra0V6eiRmht9SU906Pi3ty/tpaS7INjr2bKjvbTf81ufvOgG4sogUIpaYmytjiovdb5xQq8/Zih1Gc0Nu9qyn7zCcPeuxw2926C0pyQ2/BZe4pCTxJCZy1gMoQaQQMTuUlffv/y5dn30m/oMHh+539tmhMjvUZq/dXFi8fdaJOTnn4xMcfsvKcsNxdiKCi01w6M1+ngtrht+A6ECkEDH7op6Umyv/dN997v1S/gMHBvXzu8kFqannh97S0yU+OPRmH09PlwSf7/zjdrFDcBfOhhh+A2IXkcKA2TeUFixZ4v58+Zldu9xfze33Tb4XzljcWcuF4TcbFzfpwM56s2c/9szH3r5wvcebmHh+CK7Pwuw3YOTymNBfoosera2tkpGRIX6/X3z8ks1h4X5senvdX9j1798v5+rr3Z96t4+5a0HZ2e43Jbjht9Gj3drNfrNvSO073BZcLpypARgZWiN8HedMCpfEBcX+Prn8fBk1b57kDvcTAhCTGMwHAKhFpAAAahEpAIBaRAoAoBaRAgCoRaQAAGoRKQCAWkQKAKAWkQIAqEWkAABqESkAgFpECgCgFpECAKhFpAAAahEpAEBsRKqiokKmTp3q/kCVXYqKimTHjh2h7R0dHbJs2TLJzs6WtLQ0mT9/vjQ1NYV9jpMnT8rcuXMlJSVFcnJyZM2aNdLd3T143xEAYGRGauzYsbJu3Tqpra2VgwcPyve//3258847pa6uzm1/8MEH5dVXX5XKykrZs2ePnDp1Su6+++7Qx/f09LhAdXZ2yltvvSWbNm2SjRs3SllZ2eB/
ZwCA6Gcu0+jRo83zzz9vWlpaTEJCgqmsrAxtO378uP3T9Kampsbd3759u/F6vaaxsTG0T0VFhfH5fCYQCET8Nf1+v/u8dg0AiD6Rvo5f8jUpe1a0ZcsWaW9vd8N+9uyqq6tLiouLQ/tMnDhRCgsLpaamxt236ylTpkhu7pd/bLykpMT9rfvg2Vh/AoGA26fvAgCIfQOO1JEjR9z1pqSkJHnggQdk69atMmnSJGlsbJTExETJzMwM298GyW6z7LpvoILbg9u+Tnl5uWRkZISWgoKCgT5tAMBIiNR1110nhw8fln379snSpUultLRUjh07JlfS2rVrxe/3h5aGhoYr+vUAADrED/QD7NnShAkT3O3p06fLgQMH5Omnn5YFCxa4CREtLS1hZ1N2dl9eXp67bdf79+8P+3zB2X/Bffpjz9rsAgAYWS77fVK9vb3umpENVkJCglRXV4e21dfXuynn9pqVZdd2uLC5uTm0T1VVlZvObocMAQC45DMpO+x2++23u8kQbW1t8uKLL8ru3bvl9ddfd9eKFi9eLKtXr5asrCwXnhUrVrgwzZo1y338nDlzXIwWLlwo69evd9ehHnnkEffeKs6UAACXFSl7BnTffffJJ5984qJk39hrA/WDH/zAbX/qqafE6/W6N/Hasys7c2/Dhg2hj4+Li5Nt27a5a1k2Xqmpqe6a1hNPPDGQpwEAGCE8dh66RBk7Bd1G0k6isGdsAIDYfB3nd/cBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUACA2I7Vu3TrxeDyyatWq0GMdHR2ybNkyyc7OlrS0NJk/f740NTWFfdzJkydl7ty5kpKSIjk5ObJmzRrp7u6+nKcCAIhBlxypAwcOyO9+9zuZOnVq2OMPPvigvPrqq1JZWSl79uyRU6dOyd133x3a3tPT4wLV2dkpb731lmzatEk2btwoZWVll/edAABij7kEbW1t5tprrzVVVVXm1ltvNStXrnSPt7S0mISEBFNZWRna9/jx48Z+mZqaGnd/+/btxuv1msbGxtA+FRUVxufzmUAgENHX9/v97nPaNQAg+kT6On5JZ1J2OM+eDRUXF4c9XltbK11dXWGPT5w4UQoLC6Wmpsbdt+spU6ZIbm5uaJ+SkhJpbW2Vurq6fr9eIBBw2/suAIDYFz/QD9iyZYscOnTIDfddrLGxURITEyUzMzPscRskuy24T99ABbcHt/WnvLxcHn/88YE+VQBAlBvQmVRDQ4OsXLlS/vCHP8ioUaNkqKxdu1b8fn9osc8DABD7BhQpO5zX3NwsN954o8THx7vFTo545pln3G17RmQnRLS0tIR9nJ3dl5eX527b9cWz/YL3g/tcLCkpSXw+X9gCAIh9A4rU7Nmz5ciRI3L48OHQMmPGDLn33ntDtxMSEqS6ujr0MfX19W7KeVFRkbtv1/Zz2NgFVVVVufBMmjRpML83AMBIuiaVnp4ukydPDnssNTXVvScq+PjixYtl9erVkpWV5cKzYsUKF6ZZs2a57XPmzHExWrhwoaxfv95dh3rkkUfcZAx7xgQAwCVPnPg2Tz31lHi9XvcmXjsrz87c27BhQ2h7XFycbNu2TZYuXeriZSNXWloqTzzxxGA/FQBAlPPYeegSZewU9IyM
DDeJgutTABB9In0d53f3AQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1CJSAAC1iBQAQC0iBQBQi0gBANQiUgAAtYgUAEAtIgUAUItIAQDUIlIAALWIFABALSIFAFCLSAEA1IqXKGSMcevW1tbhfioAgEsQfP0Ovp7HVKTOnDnj1gUFBcP9VAAAl6GtrU0yMjJiK1JZWVluffLkyW/85kY6+18qNuQNDQ3i8/mG++moxXGKDMcpMhynyNgzKBuo/Pz8b9wvKiPl9Z6/lGYDxQ/Bt7PHiOP07ThOkeE4RYbj9O0iOclg4gQAQC0iBQBQKyojlZSUJI899phb4+txnCLDcYoMxykyHKfB5THfNv8PAIBhEpVnUgCAkYFIAQDUIlIAALWIFABAraiM1LPPPivjxo2TUaNGyc033yz79++XkWTv3r1yxx13uHdqezweefnll8O227kwZWVlcvXVV0tycrIUFxfL+++/H7bP2bNn5d5773VvNszMzJTFixfLuXPnJFaUl5fLzJkzJT09XXJycmTevHlSX18ftk9HR4csW7ZMsrOzJS0tTebPny9NTU1h+9jfajJ37lxJSUlxn2fNmjXS3d0tsaKiokKmTp0aeuNpUVGR7NixI7SdY9S/devWuX97q1atCj3GsbpCTJTZsmWLSUxMNL///e9NXV2duf/++01mZqZpamoyI8X27dvNL37xC/PSSy/ZmZlm69atYdvXrVtnMjIyzMsvv2zeffdd86Mf/ciMHz/efPHFF6F9brvtNjNt2jTz9ttvmzfeeMNMmDDB3HPPPSZWlJSUmBdeeMEcPXrUHD582Pzwhz80hYWF5ty5c6F9HnjgAVNQUGCqq6vNwYMHzaxZs8z3vve90Pbu7m4zefJkU1xcbN555x133MeMGWPWrl1rYsWf//xn85e//MX8/e9/N/X19ebnP/+5SUhIcMfN4hh91f79+824cePM1KlTzcqVK0OPc6yujKiL1E033WSWLVsWut/T02Py8/NNeXm5GYkujlRvb6/Jy8szTz75ZOixlpYWk5SUZDZv3uzuHzt2zH3cgQMHQvvs2LHDeDwe8/HHH5tY1Nzc7L7nPXv2hI6JfTGurKwM7XP8+HG3T01NjbtvX0S8Xq9pbGwM7VNRUWF8Pp8JBAImVo0ePdo8//zzHKN+tLW1mWuvvdZUVVWZW2+9NRQpjtWVE1XDfZ2dnVJbW+uGr/r+Hj97v6amZlifmxYnTpyQxsbGsGNkfz+WHRYNHiO7tkN8M2bMCO1j97fHct++fRKL/H5/2C8ntj9HXV1dYcdp4sSJUlhYGHacpkyZIrm5uaF9SkpK3C8Qraurk1jT09MjW7Zskfb2djfsxzH6KjucZ4fr+h4Ti2N15UTVL5g9ffq0+4fU9/9ky95/7733hu15aWIDZfV3jILb7NqOh/cVHx/vXsCD+8SS3t5ed+3glltukcmTJ7vH7PeZmJjoYv1Nx6m/4xjcFiuOHDniomSvqdhrKVu3bpVJkybJ4cOHOUZ92IAfOnRIDhw48JVt/DxdOVEVKeBS/+v36NGj8uabbw73U1Hpuuuuc0GyZ5t/+tOfpLS0VPbs2TPcT0sV+2c3Vq5cKVVVVW7CFoZOVA33jRkzRuLi4r4yY8bez8vLG7bnpUnwOHzTMbLr5ubmsO12hpGd8Rdrx3H58uWybds22bVrl4wdOzb0
uP0+7fBxS0vLNx6n/o5jcFussGcAEyZMkOnTp7tZkdOmTZOnn36aY3TRcJ79N3PjjTe6UQe72JA/88wz7rY9I+JYXRneaPvHZP8hVVdXhw3l2Pt2uAIi48ePdz/wfY+RHfO215qCx8iu7T8m+w8vaOfOne5Y2mtXscDOKbGBskNX9nuzx6Uv+3OUkJAQdpzsFHU7RbjvcbJDYX2Dbv9L2k7VtsNhscr+HAQCAY5RH7Nnz3bfpz3jDC72mq59G0fwNsfqCjFROAXdzlTbuHGjm6W2ZMkSNwW974yZWGdnGNkprHax/xf+6le/crc/+uij0BR0e0xeeeUV87e//c3ceeed/U5Bv+GGG8y+ffvMm2++6WYsxdIU9KVLl7pp+Lt37zaffPJJaPn888/Dpgzbaek7d+50U4aLiorccvGU4Tlz5rhp7K+99pq56qqrYmrK8MMPP+xmPJ44ccL9rNj7dpbnX//6V7edY/T1+s7uszhWV0bURcr6zW9+434Y7Pul7JR0+16fkWTXrl0uThcvpaWloWnojz76qMnNzXVBnz17tnsPTF9nzpxxUUpLS3NTYBctWuTiFyv6Oz52se+dCrLR/tnPfuamXKekpJi77rrLhayvf/zjH+b22283ycnJ7j0tDz30kOnq6jKx4qc//am55ppr3L8l+4Jpf1aCgbI4RpFHimN1ZfCnOgAAakXVNSkAwMhCpAAAahEpAIBaRAoAoBaRAgCoRaQAAGoRKQCAWkQKAKAWkQIAqEWkAABqESkAgFpECgAgWv0/wbtcQ80lNcQAAAAASUVORK5CYII="
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 26
  },
  {
   "cell_type": "code",
   "metadata": {
    "scrolled": false,
    "ExecuteTime": {
     "end_time": "2025-09-02T02:55:05.524306Z",
     "start_time": "2025-09-02T02:55:05.520957Z"
    }
   },
   "source": [
    "#测试游戏环境\n",
    "def test_env():\n",
    "    state = env.reset()\n",
    "    print('这个游戏的状态用3个数字表示,我也不知道这3个数字分别是什么意思,反正这3个数字就能描述游戏全部的状态')\n",
    "    print('state=', state)\n",
    "    #state= [-0.91304934 -0.40784913  0.271098  ]\n",
    "\n",
    "    print('这个游戏的动作是个-2到+2之间的连续值')\n",
    "    print('env.action_space=', env.action_space)\n",
    "    #env.action_space= Box(-2.0, 2.0, (1,), float32)\n",
    "\n",
    "    print('随机一个动作')\n",
    "    action = env.action_space.sample()\n",
    "    print('action=', action)\n",
    "    #action= [-0.14946985]\n",
    "\n",
    "    print('执行一个动作,得到下一个状态,奖励,是否结束')\n",
    "    state, reward, over, _ = env.step(action)\n",
    "\n",
    "    print('state=', state)\n",
    "    #state= [-0.5629868  0.8264659  2.7232552]\n",
    "\n",
    "    print('reward=', reward)\n",
    "    #reward= -4.456876123969679\n",
    "\n",
    "    print('over=', over)\n",
    "    #over= False\n",
    "\n",
    "\n",
    "test_env()"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "这个游戏的状态用3个数字表示,我也不知道这3个数字分别是什么意思,反正这3个数字就能描述游戏全部的状态\n",
      "state= [-0.17791817 -0.98404527  0.26085824]\n",
      "这个游戏的动作是个-2到+2之间的连续值\n",
      "env.action_space= Box(-2.0, 2.0, (1,), float32)\n",
      "随机一个动作\n",
      "action= [-0.04090903]\n",
      "执行一个动作,得到下一个状态,奖励,是否结束\n",
      "state= [-0.20164396 -0.97945887 -0.4833121 ]\n",
      "reward= -3.068140232064126\n",
      "over= False\n"
     ]
    }
   ],
   "execution_count": 27
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:55:05.543282Z",
     "start_time": "2025-09-02T02:55:05.539529Z"
    }
   },
   "source": [
     "import torch\n",
     "\n",
     "# Action-value network: maps a 3-dim state to 11 discrete-action scores.\n",
     "# This is the network that is actually trained and used to act.\n",
     "model = torch.nn.Sequential(\n",
     "    torch.nn.Linear(3, 128),\n",
     "    torch.nn.ReLU(),\n",
     "    torch.nn.Linear(128, 11),\n",
     ")\n",
     "\n",
     "# Second network used to evaluate a state's score\n",
     "# (presumably the DQN target network — same architecture as model).\n",
     "next_model = torch.nn.Sequential(\n",
     "    torch.nn.Linear(3, 128),\n",
     "    torch.nn.ReLU(),\n",
     "    torch.nn.Linear(128, 11),\n",
     ")\n",
     "\n",
     "# Copy model's parameters into next_model so both start identical.\n",
     "next_model.load_state_dict(model.state_dict())\n",
     "\n",
     "model, next_model"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Sequential(\n",
       "   (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=11, bias=True)\n",
       " ),\n",
       " Sequential(\n",
       "   (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=11, bias=True)\n",
       " ))"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 28
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:55:05.564946Z",
     "start_time": "2025-09-02T02:55:05.560627Z"
    }
   },
   "source": [
     "import random\n",
     "\n",
     "\n",
     "def get_action(state):\n",
     "    # Forward pass through the network; pick the greedy discrete action.\n",
     "    state = torch.FloatTensor(state).reshape(1, 3)\n",
     "    action = model(state).argmax().item()\n",
     "\n",
     "    # Epsilon-greedy exploration: with probability 0.01 take a random action.\n",
     "    if random.random() < 0.01:\n",
     "        action = random.choice(range(11))\n",
     "\n",
     "    # Map the discrete action index 0..10 onto the continuous range [-2, 2].\n",
     "    action_continuous = action\n",
     "    action_continuous /= 10\n",
     "    action_continuous *= 4\n",
     "    action_continuous -= 2\n",
     "\n",
     "    return action, action_continuous\n",
     "\n",
     "\n",
     "get_action([0.29292667, 0.9561349, 1.0957013])"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, -0.3999999999999999)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 29
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:55:05.595382Z",
     "start_time": "2025-09-02T02:55:05.581626Z"
    }
   },
   "source": [
    "#样本池\n",
    "datas = []\n",
    "\n",
    "\n",
    "#向样本池中添加N条数据,删除M条最古老的数据\n",
    "def update_data():\n",
    "    old_count = len(datas)\n",
    "\n",
    "    #玩到新增了N个数据为止\n",
    "    while len(datas) - old_count < 200:\n",
    "        #初始化游戏\n",
    "        state = env.reset()\n",
    "\n",
    "        #玩到游戏结束为止\n",
    "        over = False\n",
    "        while not over:\n",
    "            #根据当前状态得到一个动作\n",
    "            action, action_continuous = get_action(state)\n",
    "\n",
    "            #执行动作,得到反馈\n",
    "            next_state, reward, over, _ = env.step([action_continuous])\n",
    "\n",
    "            #记录数据样本\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            #更新游戏状态,开始下一个动作\n",
    "            state = next_state\n",
    "\n",
    "    update_count = len(datas) - old_count\n",
    "    drop_count = max(len(datas) - 5000, 0)\n",
    "\n",
    "    #数据上限,超出时从最古老的开始删除\n",
    "    while len(datas) > 5000:\n",
    "        datas.pop(0)\n",
    "\n",
    "    return update_count, drop_count\n",
    "\n",
    "\n",
    "update_data(), len(datas)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((200, 0), 200)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 30
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:55:05.621143Z",
     "start_time": "2025-09-02T02:55:05.614231Z"
    }
   },
   "source": [
    "#获取一批数据样本\n",
    "def get_sample():\n",
    "    #从样本池中采样\n",
    "    samples = random.sample(datas, 64)\n",
    "\n",
    "    #[b, 3]\n",
    "    state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    action = torch.LongTensor([i[1] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 1]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 3]\n",
    "    next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state, action, reward, next_state, over"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 0.9937, -0.1125, -2.9603],\n",
       "         [-0.6362,  0.7715, -7.5578],\n",
       "         [ 0.6992,  0.7149, -3.8732],\n",
       "         [-0.8933,  0.4494, -7.7294],\n",
       "         [-0.8933,  0.4495, -7.9484],\n",
       "         [-0.2200,  0.9755, -6.7244],\n",
       "         [ 0.4525,  0.8917, -4.8708],\n",
       "         [ 0.9856,  0.1688, -2.5134],\n",
       "         [ 0.9864, -0.1644, -2.9354],\n",
       "         [-0.6254, -0.7803, -7.5169],\n",
       "         [-0.9955, -0.0950, -8.0000],\n",
       "         [ 0.7868, -0.6172, -0.8564],\n",
       "         [ 0.9999,  0.0109, -2.6022],\n",
       "         [-0.9968,  0.0795, -7.7891],\n",
       "         [-0.2723, -0.9622, -6.7447],\n",
       "         [ 0.9450, -0.3270, -3.3587],\n",
       "         [ 0.8773, -0.4800, -3.8376],\n",
       "         [ 0.9883, -0.1523, -2.9154],\n",
       "         [ 0.9101,  0.4143, -2.0001],\n",
       "         [ 0.7580, -0.6522, -4.1975],\n",
       "         [ 0.5978, -0.8017, -4.6490],\n",
       "         [ 0.3336,  0.9427, -5.2451],\n",
       "         [ 0.9614, -0.2752, -3.3121],\n",
       "         [-0.9926, -0.1214, -8.0000],\n",
       "         [-0.6706,  0.7418, -7.3924],\n",
       "         [ 0.2095, -0.9778, -5.5494],\n",
       "         [-0.2257, -0.9742, -6.6564],\n",
       "         [ 0.4694, -0.8830, -4.8872],\n",
       "         [ 0.9527, -0.3039, -3.3097],\n",
       "         [ 0.8715, -0.4905, -3.8651],\n",
       "         [-0.9575, -0.2885, -8.0000],\n",
       "         [ 0.3168,  0.9485, -5.2949],\n",
       "         [ 0.2242,  0.9746, -5.6017],\n",
       "         [ 0.1765, -0.9843, -5.5838],\n",
       "         [-0.7850,  0.6194, -7.7939],\n",
       "         [-0.3340,  0.9426, -6.9792],\n",
       "         [-0.9745,  0.2243, -8.0000],\n",
       "         [-0.2953, -0.9554, -6.8004],\n",
       "         [-0.5110,  0.8596, -7.3294],\n",
       "         [ 0.3361, -0.9418, -5.3446],\n",
       "         [ 0.9937,  0.1125, -2.4060],\n",
       "         [ 0.8827, -0.4700, -3.8147],\n",
       "         [ 0.5555, -0.8315, -4.7890],\n",
       "         [-0.7690,  0.6393, -7.7749],\n",
       "         [ 0.7490, -0.6626, -4.2329],\n",
       "         [ 0.8131, -0.5821, -3.8931],\n",
       "         [-0.8152,  0.5792, -7.8318],\n",
       "         [-0.9849, -0.1729, -8.0000],\n",
       "         [ 0.5859, -0.8104, -4.6867],\n",
       "         [ 0.8642,  0.5031, -2.3774],\n",
       "         [-0.1794,  0.9838, -6.6418],\n",
       "         [ 0.1827,  0.9832, -5.3961],\n",
       "         [ 0.5240,  0.8517, -4.5835],\n",
       "         [ 0.9901,  0.1405, -2.4076],\n",
       "         [ 0.7218,  0.6921, -3.7981],\n",
       "         [ 0.6691, -0.7432, -4.3298],\n",
       "         [ 0.3901, -0.9208, -5.2200],\n",
       "         [-0.4390, -0.8985, -7.0288],\n",
       "         [ 0.8153,  0.5790, -3.3105],\n",
       "         [ 0.9978,  0.0661, -2.7288],\n",
       "         [ 0.9913,  0.1317, -2.4027],\n",
       "         [ 0.6459, -0.7634, -4.3381],\n",
       "         [-0.8670, -0.4983, -8.0000],\n",
       "         [ 0.9994,  0.0353, -2.6868]]),\n",
       " tensor([[0],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [0],\n",
       "         [0],\n",
       "         [5],\n",
       "         [5],\n",
       "         [0],\n",
       "         [0],\n",
       "         [5],\n",
       "         [5],\n",
       "         [0],\n",
       "         [5],\n",
       "         [0],\n",
       "         [0],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [0],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [0],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [0],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [0],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [5],\n",
       "         [0],\n",
       "         [0],\n",
       "         [5],\n",
       "         [5],\n",
       "         [0]]),\n",
       " tensor([[ -0.8930],\n",
       "         [-10.8213],\n",
       "         [ -2.1346],\n",
       "         [-13.1327],\n",
       "         [-13.4753],\n",
       "         [ -7.7352],\n",
       "         [ -3.5850],\n",
       "         [ -0.6645],\n",
       "         [ -0.8929],\n",
       "         [-10.6969],\n",
       "         [-15.6807],\n",
       "         [ -0.5198],\n",
       "         [ -0.6813],\n",
       "         [-15.4427],\n",
       "         [ -7.9590],\n",
       "         [ -1.2430],\n",
       "         [ -1.7233],\n",
       "         [ -0.8773],\n",
       "         [ -0.5865],\n",
       "         [ -2.2667],\n",
       "         [ -3.0264],\n",
       "         [ -4.2657],\n",
       "         [ -1.1787],\n",
       "         [-15.5199],\n",
       "         [-10.7818],\n",
       "         [ -4.9284],\n",
       "         [ -7.6651],\n",
       "         [ -3.5596],\n",
       "         [ -1.1947],\n",
       "         [ -1.7567],\n",
       "         [-14.5166],\n",
       "         [ -4.3623],\n",
       "         [ -4.9462],\n",
       "         [ -5.0593],\n",
       "         [-12.1931],\n",
       "         [ -8.5243],\n",
       "         [-14.8996],\n",
       "         [ -8.1236],\n",
       "         [ -9.8118],\n",
       "         [ -4.3645],\n",
       "         [ -0.5956],\n",
       "         [ -1.6946],\n",
       "         [ -3.2574],\n",
       "         [-12.0377],\n",
       "         [ -2.3164],\n",
       "         [ -1.9018],\n",
       "         [-12.5034],\n",
       "         [-15.2077],\n",
       "         [ -3.0892],\n",
       "         [ -0.8431],\n",
       "         [ -7.4779],\n",
       "         [ -4.8358],\n",
       "         [ -3.1398],\n",
       "         [ -0.6035],\n",
       "         [ -2.0269],\n",
       "         [ -2.5767],\n",
       "         [ -4.0939],\n",
       "         [ -9.0421],\n",
       "         [ -1.4773],\n",
       "         [ -0.7530],\n",
       "         [ -0.5988],\n",
       "         [ -2.6364],\n",
       "         [-13.2640],\n",
       "         [ -0.7271]]),\n",
       " tensor([[ 9.6107e-01, -2.7630e-01, -3.3447e+00],\n",
       "         [-3.3404e-01,  9.4256e-01, -6.9792e+00],\n",
       "         [ 8.0821e-01,  5.8889e-01, -3.3370e+00],\n",
       "         [-6.7065e-01,  7.4178e-01, -7.3924e+00],\n",
       "         [-6.6239e-01,  7.4916e-01, -7.6113e+00],\n",
       "         [ 7.7764e-02,  9.9697e-01, -5.9928e+00],\n",
       "         [ 6.2857e-01,  7.7775e-01, -4.2020e+00],\n",
       "         [ 9.9938e-01,  3.5293e-02, -2.6868e+00],\n",
       "         [ 9.4504e-01, -3.2696e-01, -3.3587e+00],\n",
       "         [-8.7989e-01, -4.7518e-01, -8.0000e+00],\n",
       "         [-9.5390e-01,  3.0013e-01, -8.0000e+00],\n",
       "         [ 7.3431e-01, -6.7881e-01, -1.6193e+00],\n",
       "         [ 9.9106e-01, -1.3344e-01, -2.8941e+00],\n",
       "         [-8.9333e-01,  4.4940e-01, -7.7294e+00],\n",
       "         [-6.0449e-01, -7.9662e-01, -7.4663e+00],\n",
       "         [ 8.6368e-01, -5.0405e-01, -3.9039e+00],\n",
       "         [ 7.5804e-01, -6.5221e-01, -4.1975e+00],\n",
       "         [ 9.4943e-01, -3.1398e-01, -3.3296e+00],\n",
       "         [ 9.4678e-01,  3.2188e-01, -1.9894e+00],\n",
       "         [ 5.8588e-01, -8.1040e-01, -4.6867e+00],\n",
       "         [ 3.6923e-01, -9.2934e-01, -5.2503e+00],\n",
       "         [ 5.3710e-01,  8.4352e-01, -4.5380e+00],\n",
       "         [ 8.9168e-01, -4.5267e-01, -3.8185e+00],\n",
       "         [-9.6152e-01,  2.7474e-01, -8.0000e+00],\n",
       "         [-3.8322e-01,  9.2366e-01, -6.8361e+00],\n",
       "         [-1.0285e-01, -9.9470e-01, -6.2828e+00],\n",
       "         [-5.6214e-01, -8.2704e-01, -7.3871e+00],\n",
       "         [ 2.0954e-01, -9.7780e-01, -5.5494e+00],\n",
       "         [ 8.7728e-01, -4.7997e-01, -3.8376e+00],\n",
       "         [ 7.4897e-01, -6.6260e-01, -4.2329e+00],\n",
       "         [-9.9424e-01,  1.0717e-01, -8.0000e+00],\n",
       "         [ 5.2398e-01,  8.5173e-01, -4.5835e+00],\n",
       "         [ 4.5255e-01,  8.9174e-01, -4.8708e+00],\n",
       "         [-1.3819e-01, -9.9041e-01, -6.3220e+00],\n",
       "         [-5.1096e-01,  8.5960e-01, -7.3294e+00],\n",
       "         [-2.6971e-02,  9.9964e-01, -6.2723e+00],\n",
       "         [-8.1517e-01,  5.7922e-01, -7.8318e+00],\n",
       "         [-6.2539e-01, -7.8031e-01, -7.5169e+00],\n",
       "         [-2.0070e-01,  9.7965e-01, -6.6847e+00],\n",
       "         [ 4.0226e-02, -9.9919e-01, -6.0510e+00],\n",
       "         [ 9.9983e-01, -1.8375e-02, -2.6216e+00],\n",
       "         [ 7.6636e-01, -6.4242e-01, -4.1672e+00],\n",
       "         [ 3.1299e-01, -9.4976e-01, -5.4126e+00],\n",
       "         [-4.9031e-01,  8.7155e-01, -7.2954e+00],\n",
       "         [ 5.7288e-01, -8.1964e-01, -4.7299e+00],\n",
       "         [ 6.6906e-01, -7.4321e-01, -4.3298e+00],\n",
       "         [-5.5066e-01,  8.3473e-01, -7.3974e+00],\n",
       "         [-9.7453e-01,  2.2426e-01, -8.0000e+00],\n",
       "         [ 3.5343e-01, -9.3546e-01, -5.2945e+00],\n",
       "         [ 9.1014e-01,  4.1431e-01, -2.0001e+00],\n",
       "         [ 1.1459e-01,  9.9341e-01, -5.9039e+00],\n",
       "         [ 4.0468e-01,  9.1446e-01, -4.6587e+00],\n",
       "         [ 6.8073e-01,  7.3254e-01, -3.9447e+00],\n",
       "         [ 9.9994e-01,  1.0860e-02, -2.6022e+00],\n",
       "         [ 8.2506e-01,  5.6505e-01, -3.2790e+00],\n",
       "         [ 4.6937e-01, -8.8300e-01, -4.8872e+00],\n",
       "         [ 1.0500e-01, -9.9447e-01, -5.9105e+00],\n",
       "         [-7.4438e-01, -6.6776e-01, -7.7027e+00],\n",
       "         [ 8.8990e-01,  4.5616e-01, -2.8763e+00],\n",
       "         [ 9.9657e-01, -8.2759e-02, -2.9793e+00],\n",
       "         [ 1.0000e+00,  1.8663e-03, -2.6040e+00],\n",
       "         [ 4.4095e-01, -8.9753e-01, -4.9107e+00],\n",
       "         [-9.9261e-01, -1.2138e-01, -8.0000e+00],\n",
       "         [ 9.9365e-01, -1.1248e-01, -2.9603e+00]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 31
  },
  {
   "cell_type": "code",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2025-09-02T02:55:05.637022Z",
     "start_time": "2025-09-02T02:55:05.632606Z"
    }
   },
   "source": [
    "def get_value(state, action):\n",
    "    #使用状态计算出动作的logits\n",
    "    #[b, 3] -> [b, 11]\n",
    "    value = model(state)\n",
    "\n",
    "    #根据实际使用的action取出每一个值\n",
    "    #这个值就是模型评估的在该状态下,执行动作的分数\n",
    "    #在执行动作前,显然并不知道会得到的反馈和next_state\n",
    "    #所以这里不能也不需要考虑next_state和reward\n",
    "    #[b, 11] -> [b, 1]\n",
    "    value = value.gather(dim=1, index=action)\n",
    "\n",
    "    return value\n",
    "\n",
    "\n",
    "get_value(state, action)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.2009],\n",
       "        [0.9778],\n",
       "        [0.3856],\n",
       "        [0.9723],\n",
       "        [1.0067],\n",
       "        [0.8871],\n",
       "        [0.5784],\n",
       "        [0.1894],\n",
       "        [0.2091],\n",
       "        [0.8506],\n",
       "        [0.9757],\n",
       "        [0.3394],\n",
       "        [0.2113],\n",
       "        [0.9533],\n",
       "        [0.7068],\n",
       "        [0.1950],\n",
       "        [0.2408],\n",
       "        [0.2093],\n",
       "        [0.1939],\n",
       "        [0.2859],\n",
       "        [0.3548],\n",
       "        [0.6476],\n",
       "        [0.1933],\n",
       "        [0.9741],\n",
       "        [0.9500],\n",
       "        [0.5032],\n",
       "        [0.6910],\n",
       "        [0.3948],\n",
       "        [0.1966],\n",
       "        [0.2437],\n",
       "        [0.9614],\n",
       "        [0.6567],\n",
       "        [0.7116],\n",
       "        [0.5101],\n",
       "        [0.9978],\n",
       "        [0.9182],\n",
       "        [0.9951],\n",
       "        [0.7163],\n",
       "        [0.9549],\n",
       "        [0.4663],\n",
       "        [0.2115],\n",
       "        [0.2385],\n",
       "        [0.3772],\n",
       "        [0.9971],\n",
       "        [0.2916],\n",
       "        [0.2396],\n",
       "        [1.0001],\n",
       "        [0.9709],\n",
       "        [0.3609],\n",
       "        [0.1521],\n",
       "        [0.8766],\n",
       "        [0.6790],\n",
       "        [0.5233],\n",
       "        [0.2057],\n",
       "        [0.3716],\n",
       "        [0.3064],\n",
       "        [0.4456],\n",
       "        [0.7621],\n",
       "        [0.2894],\n",
       "        [0.1928],\n",
       "        [0.2080],\n",
       "        [0.3077],\n",
       "        [0.9440],\n",
       "        [0.2009]], grad_fn=<GatherBackward0>)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 32
  },
  {
   "cell_type": "code",
   "metadata": {
    "scrolled": true,
    "ExecuteTime": {
     "end_time": "2025-09-02T02:55:05.697809Z",
     "start_time": "2025-09-02T02:55:05.693989Z"
    }
   },
   "source": [
    "def get_target(reward, next_state, over):\n",
    "    #上面已经把模型认为的状态下执行动作的分数给评估出来了\n",
    "    #下面使用next_state和reward计算真实的分数\n",
    "    #针对一个状态,它到底应该多少分,可以使用以往模型积累的经验评估\n",
    "    #这也是没办法的办法,因为显然没有精确解,这里使用延迟更新的next_model评估\n",
    "\n",
    "    #使用next_state计算下一个状态的分数\n",
    "    #[b, 3] -> [b, 11]\n",
    "    with torch.no_grad():\n",
    "        target = next_model(next_state)\n",
    "\n",
    "    #取所有动作中分数最大的\n",
    "    #[b, 11] -> [b, 1]\n",
    "    target = target.max(dim=1)[0]\n",
    "    target = target.reshape(-1, 1)\n",
    "\n",
    "    #下一个状态的分数乘以一个系数,相当于权重\n",
    "    target *= 0.98\n",
    "\n",
    "    #如果next_state已经游戏结束,则next_state的分数是0\n",
    "    #因为如果下一步已经游戏结束,显然不需要再继续玩下去,也就不需要考虑next_state了.\n",
    "    #[b, 1] * [b, 1] -> [b, 1]\n",
    "    target *= (1 - over)\n",
    "\n",
    "    #加上reward就是最终的分数\n",
    "    #[b, 1] + [b, 1] -> [b, 1]\n",
    "    target += reward\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ -0.7034],\n",
       "        [ -9.9215],\n",
       "        [ -1.8467],\n",
       "        [-12.2017],\n",
       "        [-12.5113],\n",
       "        [ -6.9711],\n",
       "        [ -3.1473],\n",
       "        [ -0.4676],\n",
       "        [ -0.7018],\n",
       "        [ -9.7699],\n",
       "        [-14.6991],\n",
       "        [ -0.2125],\n",
       "        [ -0.4768],\n",
       "        [-14.4898],\n",
       "        [ -7.1348],\n",
       "        [ -1.0001],\n",
       "        [ -1.4431],\n",
       "        [ -0.6852],\n",
       "        [ -0.3736],\n",
       "        [ -1.9131],\n",
       "        [ -2.5844],\n",
       "        [ -3.7619],\n",
       "        [ -0.9422],\n",
       "        [-14.5404],\n",
       "        [ -9.9069],\n",
       "        [ -4.3122],\n",
       "        [ -6.8563],\n",
       "        [ -3.0665],\n",
       "        [ -0.9588],\n",
       "        [ -1.4709],\n",
       "        [-13.5476],\n",
       "        [ -3.8495],\n",
       "        [ -4.3793],\n",
       "        [ -4.4346],\n",
       "        [-11.2573],\n",
       "        [ -7.7169],\n",
       "        [-13.9195],\n",
       "        [ -7.2900],\n",
       "        [ -8.9473],\n",
       "        [ -3.7915],\n",
       "        [ -0.3858],\n",
       "        [ -1.4191],\n",
       "        [ -2.7897],\n",
       "        [-11.1048],\n",
       "        [ -1.9560],\n",
       "        [ -1.6015],\n",
       "        [-11.5620],\n",
       "        [-14.2325],\n",
       "        [ -2.6402],\n",
       "        [ -0.6531],\n",
       "        [ -6.7282],\n",
       "        [ -4.3004],\n",
       "        [ -2.7485],\n",
       "        [ -0.3965],\n",
       "        [ -1.7485],\n",
       "        [ -2.1898],\n",
       "        [ -3.5447],\n",
       "        [ -8.1688],\n",
       "        [ -1.2630],\n",
       "        [ -0.7530],\n",
       "        [ -0.3905],\n",
       "        [ -2.2450],\n",
       "        [-12.3093],\n",
       "        [ -0.5303]])"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 33
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-09-02T02:55:05.759458Z",
     "start_time": "2025-09-02T02:55:05.749871Z"
    }
   },
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    #初始化游戏\n",
    "    state = env.reset()\n",
    "\n",
    "    #记录反馈值的和,这个值越大越好\n",
    "    reward_sum = 0\n",
    "\n",
    "    #玩到游戏结束为止\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据当前状态得到一个动作\n",
    "        _, action_continuous = get_action(state)\n",
    "\n",
    "        #执行动作,得到反馈\n",
    "        state, reward, over, _ = env.step([action_continuous])\n",
    "        reward_sum += reward\n",
    "\n",
    "        #打印动画\n",
    "        if play and random.random() < 0.2:  #跳帧\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "test(play=False)"
   ],
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1032.5145065225263"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 34
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false,
    "ExecuteTime": {
     "end_time": "2025-09-02T02:55:21.904742Z",
     "start_time": "2025-09-02T02:55:05.792419Z"
    }
   },
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    #训练N次\n",
    "    for epoch in range(200):\n",
    "        #更新N条数据\n",
    "        update_count, drop_count = update_data()\n",
    "\n",
    "        #每次更新过数据后,学习N次\n",
    "        for i in range(200):\n",
    "            #采样一批数据\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "            #计算一批样本的value和target\n",
    "            value = get_value(state, action)\n",
    "            target = get_target(reward, next_state, over)\n",
    "\n",
    "            #更新参数\n",
    "            loss = loss_fn(value, target)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            #把model的参数复制给next_model\n",
    "            if (i + 1) % 50 == 0:\n",
    "                next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "        if epoch % 20 == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(20)]) / 20\n",
    "            print(epoch, len(datas), update_count, drop_count, test_result)\n",
    "\n",
    "\n",
    "train()"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 400 200 0 -1539.4044512652192\n",
      "20 4400 200 0 -1201.7792155360735\n",
      "40 5000 200 200 -493.0457594896164\n",
      "60 5000 200 200 -775.7228825237967\n",
      "80 5000 200 200 -295.15478489202053\n",
      "100 5000 200 200 -198.87222125568323\n",
      "120 5000 200 200 -158.04708298685378\n",
      "140 5000 200 200 -489.29422260717894\n",
      "160 5000 200 200 -190.88758883198295\n",
      "180 5000 200 200 -287.71903958260833\n"
     ]
    }
   ],
   "execution_count": 35
  },
  {
   "cell_type": "code",
   "metadata": {},
   "source": [
    "test(play=True)"
   ],
   "execution_count": 36,
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-127.43501860055763"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
